{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3",
   "language": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "source": [
    "# I - Consolidation of the datasets"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "   timestamp_start  timestamp_end  qSQI_score  cSQI_score  sSQI_score  \\\n0         28800001       28803000        0.67        0.71        3.11   \n1         28803001       28806000        0.86        0.61        3.71   \n2         28806001       28809000        0.67        0.48        3.16   \n3         28809001       28812000        0.80        0.82        3.13   \n4         28812001       28815000        1.00        0.68        3.32   \n\n   kSQI_score  pSQI_score  basSQI_score  classif  classif_avg  patient  \n0       11.66        0.52          0.97        1          1.0   103001  \n1       15.86        0.53          0.96        1          1.0   103001  \n2       12.78        0.52          0.97        1          1.0   103001  \n3       11.77        0.50          0.94        1          1.0   103001  \n4       13.10        0.50          0.97        1          1.0   103001  ",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>timestamp_start</th>\n      <th>timestamp_end</th>\n      <th>qSQI_score</th>\n      <th>cSQI_score</th>\n      <th>sSQI_score</th>\n      <th>kSQI_score</th>\n      <th>pSQI_score</th>\n      <th>basSQI_score</th>\n      <th>classif</th>\n      <th>classif_avg</th>\n      <th>patient</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>0</th>\n      <td>28800001</td>\n      <td>28803000</td>\n      <td>0.67</td>\n      <td>0.71</td>\n      <td>3.11</td>\n      <td>11.66</td>\n      <td>0.52</td>\n      <td>0.97</td>\n      <td>1</td>\n      <td>1.0</td>\n      <td>103001</td>\n    </tr>\n    <tr>\n      <th>1</th>\n      <td>28803001</td>\n      <td>28806000</td>\n      <td>0.86</td>\n      <td>0.61</td>\n      <td>3.71</td>\n      <td>15.86</td>\n      <td>0.53</td>\n      <td>0.96</td>\n      <td>1</td>\n      <td>1.0</td>\n      <td>103001</td>\n    </tr>\n    <tr>\n      <th>2</th>\n      <td>28806001</td>\n      <td>28809000</td>\n      <td>0.67</td>\n      <td>0.48</td>\n      <td>3.16</td>\n      <td>12.78</td>\n      <td>0.52</td>\n      <td>0.97</td>\n      <td>1</td>\n      <td>1.0</td>\n      <td>103001</td>\n    </tr>\n    <tr>\n      <th>3</th>\n      <td>28809001</td>\n      <td>28812000</td>\n      <td>0.80</td>\n      <td>0.82</td>\n      <td>3.13</td>\n      <td>11.77</td>\n      <td>0.50</td>\n      <td>0.94</td>\n      <td>1</td>\n      <td>1.0</td>\n      <td>103001</td>\n    </tr>\n    <tr>\n      <th>4</th>\n      <td>28812001</td>\n      <td>28815000</td>\n      <td>1.00</td>\n      <td>0.68</td>\n      <td>3.32</td>\n      <td>13.10</td>\n      
<td>0.50</td>\n      <td>0.97</td>\n      <td>1</td>\n      <td>1.0</td>\n      <td>103001</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "\nPatient count by classification:\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "         classif\npatient         \n103001       800\n111001     30214\n113001      1200\n124001      1600",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>classif</th>\n    </tr>\n    <tr>\n      <th>patient</th>\n      <th></th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>103001</th>\n      <td>800</td>\n    </tr>\n    <tr>\n      <th>111001</th>\n      <td>30214</td>\n    </tr>\n    <tr>\n      <th>113001</th>\n      <td>1200</td>\n    </tr>\n    <tr>\n      <th>124001</th>\n      <td>1600</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "\nPatient average by classification:\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "          classif  classif_avg\npatient                       \n103001   0.553750     0.723393\n111001   0.348779     0.419462\n113001   0.527500     0.661356\n124001   0.113750     0.286912",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>classif</th>\n      <th>classif_avg</th>\n    </tr>\n    <tr>\n      <th>patient</th>\n      <th></th>\n      <th></th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>103001</th>\n      <td>0.553750</td>\n      <td>0.723393</td>\n    </tr>\n    <tr>\n      <th>111001</th>\n      <td>0.348779</td>\n      <td>0.419462</td>\n    </tr>\n    <tr>\n      <th>113001</th>\n      <td>0.527500</td>\n      <td>0.661356</td>\n    </tr>\n    <tr>\n      <th>124001</th>\n      <td>0.113750</td>\n      <td>0.286912</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "metadata": {}
    }
   ],
   "source": [
    "patients = [103001, 111001, 113001, 124001]\n",
    "\n",
    "# Read every patient's CSV, tag its rows with the patient id, and stack the\n",
    "# frames with a single concat instead of re-copying a growing frame per loop.\n",
    "patient_frames = []\n",
    "for patient in patients:\n",
    "    df_temp = pd.read_csv('../../datasets/2_dataset_creation_3s/df_ml_{}_norm.csv'.format(patient))\n",
    "    df_temp['patient'] = patient\n",
    "    patient_frames.append(df_temp)\n",
    "\n",
    "# Each frame keeps its own row labels, so index values repeat across patients.\n",
    "df_ml_conso = pd.concat(patient_frames, axis=0)\n",
    "\n",
    "display(df_ml_conso.head())\n",
    "\n",
    "print('\\nPatient count by classification:')\n",
    "display(pd.pivot_table(df_ml_conso, values=['classif'], index='patient', aggfunc='count'))\n",
    "\n",
    "# The string 'mean' is used instead of np.mean: newer pandas warns when a NumPy\n",
    "# callable is passed as aggfunc, and the result is the same.\n",
    "print('\\nPatient average by classification:')\n",
    "display(pd.pivot_table(df_ml_conso, values=['classif', 'classif_avg'], index='patient', aggfunc='mean'))"
   ]
  },
  {
   "source": [
    "# II - Applying the quality threshold"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "\nPatient count:\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "         classif\npatient         \n103001       800\n111001     30214\n113001      1200\n124001      1600",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>classif</th>\n    </tr>\n    <tr>\n      <th>patient</th>\n      <th></th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>103001</th>\n      <td>800</td>\n    </tr>\n    <tr>\n      <th>111001</th>\n      <td>30214</td>\n    </tr>\n    <tr>\n      <th>113001</th>\n      <td>1200</td>\n    </tr>\n    <tr>\n      <th>124001</th>\n      <td>1600</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "\nPatient average by classification:\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "          classif  classif_avg  classif_threshold\npatient                                          \n103001   0.553750     0.723393           0.560000\n111001   0.348779     0.419462           0.353379\n113001   0.527500     0.661356           0.535000\n124001   0.113750     0.286912           0.125625",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>classif</th>\n      <th>classif_avg</th>\n      <th>classif_threshold</th>\n    </tr>\n    <tr>\n      <th>patient</th>\n      <th></th>\n      <th></th>\n      <th></th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>103001</th>\n      <td>0.553750</td>\n      <td>0.723393</td>\n      <td>0.560000</td>\n    </tr>\n    <tr>\n      <th>111001</th>\n      <td>0.348779</td>\n      <td>0.419462</td>\n      <td>0.353379</td>\n    </tr>\n    <tr>\n      <th>113001</th>\n      <td>0.527500</td>\n      <td>0.661356</td>\n      <td>0.535000</td>\n    </tr>\n    <tr>\n      <th>124001</th>\n      <td>0.113750</td>\n      <td>0.286912</td>\n      <td>0.125625</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "metadata": {}
    }
   ],
   "source": [
    "# Minimum classif_avg for an observation to be labelled 1 (otherwise 0).\n",
    "# Author's note: the threshold is a proportion of 'optimal' per observation.\n",
    "classif_threshold = 0.95\n",
    "\n",
    "# Vectorised comparison instead of a row-wise apply; a missing classif_avg\n",
    "# compares False and is therefore labelled 0, exactly as before.\n",
    "df_ml_conso['classif_threshold'] = (df_ml_conso['classif_avg'] >= classif_threshold).astype('int64')\n",
    "\n",
    "print('\\nPatient count:')\n",
    "display(pd.pivot_table(df_ml_conso, values=['classif'], index='patient', aggfunc='count'))\n",
    "\n",
    "# 'mean' (string) avoids the warning newer pandas emits for aggfunc=np.mean.\n",
    "print('\\nPatient average by classification:')\n",
    "display(pd.pivot_table(df_ml_conso, values=['classif', 'classif_avg', 'classif_threshold'], index='patient', aggfunc='mean'))"
   ]
  },
  {
   "source": [
    "# III - Creation of a validation dataset"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 20 % of the consolidated rows; the fixed random_state keeps the split repeatable.\n",
    "df_ml_conso_for_model, df_ml_conso_validation = train_test_split(\n",
    "    df_ml_conso,\n",
    "    test_size=0.2,\n",
    "    random_state=42,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "   timestamp_start  timestamp_end  qSQI_score  cSQI_score  sSQI_score  \\\n0         28800001       28803000        0.67        0.71        3.11   \n1         28803001       28806000        0.86        0.61        3.71   \n2         28806001       28809000        0.67        0.48        3.16   \n3         28809001       28812000        0.80        0.82        3.13   \n4         28812001       28815000        1.00        0.68        3.32   \n\n   kSQI_score  pSQI_score  basSQI_score  classif  classif_avg  patient  \\\n0       11.66        0.52          0.97        1          1.0   103001   \n1       15.86        0.53          0.96        1          1.0   103001   \n2       12.78        0.52          0.97        1          1.0   103001   \n3       11.77        0.50          0.94        1          1.0   103001   \n4       13.10        0.50          0.97        1          1.0   103001   \n\n   classif_threshold  \n0                  1  \n1                  1  \n2                  1  \n3                  1  \n4                  1  ",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>timestamp_start</th>\n      <th>timestamp_end</th>\n      <th>qSQI_score</th>\n      <th>cSQI_score</th>\n      <th>sSQI_score</th>\n      <th>kSQI_score</th>\n      <th>pSQI_score</th>\n      <th>basSQI_score</th>\n      <th>classif</th>\n      <th>classif_avg</th>\n      <th>patient</th>\n      <th>classif_threshold</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>0</th>\n      <td>28800001</td>\n      <td>28803000</td>\n      <td>0.67</td>\n      <td>0.71</td>\n      <td>3.11</td>\n      <td>11.66</td>\n      <td>0.52</td>\n      <td>0.97</td>\n      <td>1</td>\n      <td>1.0</td>\n      <td>103001</td>\n      <td>1</td>\n    </tr>\n    <tr>\n      <th>1</th>\n      <td>28803001</td>\n      <td>28806000</td>\n      <td>0.86</td>\n      <td>0.61</td>\n      <td>3.71</td>\n      <td>15.86</td>\n      <td>0.53</td>\n      <td>0.96</td>\n      <td>1</td>\n      <td>1.0</td>\n      <td>103001</td>\n      <td>1</td>\n    </tr>\n    <tr>\n      <th>2</th>\n      <td>28806001</td>\n      <td>28809000</td>\n      <td>0.67</td>\n      <td>0.48</td>\n      <td>3.16</td>\n      <td>12.78</td>\n      <td>0.52</td>\n      <td>0.97</td>\n      <td>1</td>\n      <td>1.0</td>\n      <td>103001</td>\n      <td>1</td>\n    </tr>\n    <tr>\n      <th>3</th>\n      <td>28809001</td>\n      <td>28812000</td>\n      <td>0.80</td>\n      <td>0.82</td>\n      <td>3.13</td>\n      <td>11.77</td>\n      <td>0.50</td>\n      <td>0.94</td>\n      <td>1</td>\n      <td>1.0</td>\n      <td>103001</td>\n      <td>1</td>\n    </tr>\n    <tr>\n      <th>4</th>\n      <td>28812001</td>\n      
<td>28815000</td>\n      <td>1.00</td>\n      <td>0.68</td>\n      <td>3.32</td>\n      <td>13.10</td>\n      <td>0.50</td>\n      <td>0.97</td>\n      <td>1</td>\n      <td>1.0</td>\n      <td>103001</td>\n      <td>1</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Shape : (33814, 12)\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "       timestamp_start  timestamp_end  qSQI_score  cSQI_score  sSQI_score  \\\n15805         47416472       47419471        1.00        0.84        0.51   \n9587          28761638       28764638        0.33        0.61       -0.25   \n23768         71306119       71309118        0.67        1.27        5.33   \n6672          20016299       20019298        0.22        0.63        0.96   \n7929          23787377       23790376        0.33        0.52       -2.55   \n\n       kSQI_score  pSQI_score  basSQI_score  classif  classif_avg  patient  \\\n15805       -0.65        0.55          0.48        0          0.0   111001   \n9587        -1.73        0.50          0.49        0          0.0   111001   \n23768       36.67        0.49          0.91        1          1.0   111001   \n6672         3.78        0.56          0.90        0          0.0   111001   \n7929         6.28        0.53          0.73        0          0.0   111001   \n\n       classif_threshold  \n15805                  0  \n9587                   0  \n23768                  1  \n6672                   0  \n7929                   0  ",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>timestamp_start</th>\n      <th>timestamp_end</th>\n      <th>qSQI_score</th>\n      <th>cSQI_score</th>\n      <th>sSQI_score</th>\n      <th>kSQI_score</th>\n      <th>pSQI_score</th>\n      <th>basSQI_score</th>\n      <th>classif</th>\n      <th>classif_avg</th>\n      <th>patient</th>\n      <th>classif_threshold</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>15805</th>\n      <td>47416472</td>\n      <td>47419471</td>\n      <td>1.00</td>\n      <td>0.84</td>\n      <td>0.51</td>\n      <td>-0.65</td>\n      <td>0.55</td>\n      <td>0.48</td>\n      <td>0</td>\n      <td>0.0</td>\n      <td>111001</td>\n      <td>0</td>\n    </tr>\n    <tr>\n      <th>9587</th>\n      <td>28761638</td>\n      <td>28764638</td>\n      <td>0.33</td>\n      <td>0.61</td>\n      <td>-0.25</td>\n      <td>-1.73</td>\n      <td>0.50</td>\n      <td>0.49</td>\n      <td>0</td>\n      <td>0.0</td>\n      <td>111001</td>\n      <td>0</td>\n    </tr>\n    <tr>\n      <th>23768</th>\n      <td>71306119</td>\n      <td>71309118</td>\n      <td>0.67</td>\n      <td>1.27</td>\n      <td>5.33</td>\n      <td>36.67</td>\n      <td>0.49</td>\n      <td>0.91</td>\n      <td>1</td>\n      <td>1.0</td>\n      <td>111001</td>\n      <td>1</td>\n    </tr>\n    <tr>\n      <th>6672</th>\n      <td>20016299</td>\n      <td>20019298</td>\n      <td>0.22</td>\n      <td>0.63</td>\n      <td>0.96</td>\n      <td>3.78</td>\n      <td>0.56</td>\n      <td>0.90</td>\n      <td>0</td>\n      <td>0.0</td>\n      <td>111001</td>\n      <td>0</td>\n    </tr>\n    <tr>\n      <th>7929</th>\n      <td>23787377</td>\n  
    <td>23790376</td>\n      <td>0.33</td>\n      <td>0.52</td>\n      <td>-2.55</td>\n      <td>6.28</td>\n      <td>0.53</td>\n      <td>0.73</td>\n      <td>0</td>\n      <td>0.0</td>\n      <td>111001</td>\n      <td>0</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Shape : (27051, 12)\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "       timestamp_start  timestamp_end  qSQI_score  cSQI_score  sSQI_score  \\\n15805         47416472       47419471        1.00        0.84        0.51   \n9587          28761638       28764638        0.33        0.61       -0.25   \n23768         71306119       71309118        0.67        1.27        5.33   \n6672          20016299       20019298        0.22        0.63        0.96   \n7929          23787377       23790376        0.33        0.52       -2.55   \n\n       kSQI_score  pSQI_score  basSQI_score  classification  \n15805       -0.65        0.55          0.48               0  \n9587        -1.73        0.50          0.49               0  \n23768       36.67        0.49          0.91               1  \n6672         3.78        0.56          0.90               0  \n7929         6.28        0.53          0.73               0  ",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>timestamp_start</th>\n      <th>timestamp_end</th>\n      <th>qSQI_score</th>\n      <th>cSQI_score</th>\n      <th>sSQI_score</th>\n      <th>kSQI_score</th>\n      <th>pSQI_score</th>\n      <th>basSQI_score</th>\n      <th>classification</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>15805</th>\n      <td>47416472</td>\n      <td>47419471</td>\n      <td>1.00</td>\n      <td>0.84</td>\n      <td>0.51</td>\n      <td>-0.65</td>\n      <td>0.55</td>\n      <td>0.48</td>\n      <td>0</td>\n    </tr>\n    <tr>\n      <th>9587</th>\n      <td>28761638</td>\n      <td>28764638</td>\n      <td>0.33</td>\n      <td>0.61</td>\n      <td>-0.25</td>\n      <td>-1.73</td>\n      <td>0.50</td>\n      <td>0.49</td>\n      <td>0</td>\n    </tr>\n    <tr>\n      <th>23768</th>\n      <td>71306119</td>\n      <td>71309118</td>\n      <td>0.67</td>\n      <td>1.27</td>\n      <td>5.33</td>\n      <td>36.67</td>\n      <td>0.49</td>\n      <td>0.91</td>\n      <td>1</td>\n    </tr>\n    <tr>\n      <th>6672</th>\n      <td>20016299</td>\n      <td>20019298</td>\n      <td>0.22</td>\n      <td>0.63</td>\n      <td>0.96</td>\n      <td>3.78</td>\n      <td>0.56</td>\n      <td>0.90</td>\n      <td>0</td>\n    </tr>\n    <tr>\n      <th>7929</th>\n      <td>23787377</td>\n      <td>23790376</td>\n      <td>0.33</td>\n      <td>0.52</td>\n      <td>-2.55</td>\n      <td>6.28</td>\n      <td>0.53</td>\n      <td>0.73</td>\n      <td>0</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Shape : (27051, 9)\n"
     ]
    }
   ],
   "source": [
    "# Reduce the held-out split to the score columns plus the thresholded label.\n",
    "# This must start from df_ml_conso_validation (the 20 % split), not from\n",
    "# df_ml_conso_for_model: otherwise the training rows are exported as the\n",
    "# validation set and the real held-out rows are lost.\n",
    "# errors='ignore' and the by-name rename keep the cell safe to re-run.\n",
    "df_ml_conso_validation = (\n",
    "    df_ml_conso_validation\n",
    "    .drop(columns=['patient', 'classif', 'classif_avg'], errors='ignore')\n",
    "    .rename(columns={'classif_threshold': 'classification'})\n",
    ")\n",
    "\n",
    "# Show a preview and the shape of the full, training and validation frames.\n",
    "for element in [df_ml_conso, df_ml_conso_for_model, df_ml_conso_validation]:\n",
    "    display(element.head())\n",
    "    print('Shape : {}'.format(element.shape))\n",
    "\n",
    "df_ml_conso_validation.to_csv('../../datasets/3_ml_patients_consolidation/df_ml_conso_validation_norm_3s.csv', index=False)"
   ]
  },
  {
   "source": [
    "# IV - Balancing the class distribution per patient"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "\nPatient count:\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "         classif\npatient         \n103001       568\n111001     17100\n113001       900\n124001       342",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>classif</th>\n    </tr>\n    <tr>\n      <th>patient</th>\n      <th></th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>103001</th>\n      <td>568</td>\n    </tr>\n    <tr>\n      <th>111001</th>\n      <td>17100</td>\n    </tr>\n    <tr>\n      <th>113001</th>\n      <td>900</td>\n    </tr>\n    <tr>\n      <th>124001</th>\n      <td>342</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "\nPatient average by classification:\n"
     ]
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "          classif  classif_avg  classif_threshold\npatient                                          \n103001   0.492958     0.690294                0.5\n111001   0.493509     0.552540                0.5\n113001   0.493333     0.636521                0.5\n124001   0.450292     0.593627                0.5",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>classif</th>\n      <th>classif_avg</th>\n      <th>classif_threshold</th>\n    </tr>\n    <tr>\n      <th>patient</th>\n      <th></th>\n      <th></th>\n      <th></th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>103001</th>\n      <td>0.492958</td>\n      <td>0.690294</td>\n      <td>0.5</td>\n    </tr>\n    <tr>\n      <th>111001</th>\n      <td>0.493509</td>\n      <td>0.552540</td>\n      <td>0.5</td>\n    </tr>\n    <tr>\n      <th>113001</th>\n      <td>0.493333</td>\n      <td>0.636521</td>\n      <td>0.5</td>\n    </tr>\n    <tr>\n      <th>124001</th>\n      <td>0.450292</td>\n      <td>0.593627</td>\n      <td>0.5</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "\nProportion of each class\n1    9455\n0    9455\nName: classif_threshold, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "# Fixed seed so the undersampled rows are the same on every run.\n",
    "RANDOM_STATE = 42\n",
    "\n",
    "# For each patient, keep the smaller class whole and draw an equally sized\n",
    "# random subset of the larger class; class-0 rows are stacked before class-1.\n",
    "# Frames are collected in a list and concatenated once at the end.\n",
    "balanced_frames = []\n",
    "for patient in patients:\n",
    "    df_patient = df_ml_conso_for_model[df_ml_conso_for_model['patient'] == patient]\n",
    "    df_class1 = df_patient[df_patient['classif_threshold'] == 1]\n",
    "    df_class0 = df_patient[df_patient['classif_threshold'] == 0]\n",
    "\n",
    "    if df_class1.shape[0] >= df_class0.shape[0]:\n",
    "        balanced_frames.extend([\n",
    "            df_class0,\n",
    "            df_class1.sample(df_class0.shape[0], random_state=RANDOM_STATE),\n",
    "        ])\n",
    "    else:\n",
    "        balanced_frames.extend([\n",
    "            df_class0.sample(df_class1.shape[0], random_state=RANDOM_STATE),\n",
    "            df_class1,\n",
    "        ])\n",
    "\n",
    "df_ml_conso_balanced = pd.concat(balanced_frames)\n",
    "\n",
    "print('\\nPatient count:')\n",
    "display(pd.pivot_table(df_ml_conso_balanced, values=['classif'], index='patient', aggfunc='count'))\n",
    "\n",
    "# 'mean' (string) avoids the warning newer pandas emits for aggfunc=np.mean.\n",
    "print('\\nPatient average by classification:')\n",
    "display(pd.pivot_table(df_ml_conso_balanced, values=['classif', 'classif_avg', 'classif_threshold'], index='patient', aggfunc='mean'))\n",
    "\n",
    "print('\\nProportion of each class')\n",
    "print(df_ml_conso_balanced['classif_threshold'].value_counts())"
   ]
  },
  {
   "source": [
    "# V - Export"
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "     timestamp_start  timestamp_end  qSQI_score  cSQI_score  sSQI_score  \\\n741         58623175       58626174        0.50        0.63        3.57   \n640         58320141       58323140        0.80        0.88        4.19   \n431         57693034       57696034        0.86        0.54        4.27   \n512         57936084       57939084        1.00        0.71        3.97   \n647         58341143       58344142        0.75        0.52        4.27   \n\n     kSQI_score  pSQI_score  basSQI_score  classification  \n741       17.52        0.49          0.93               0  \n640       20.28        0.50          0.98               0  \n431       21.04        0.53          0.99               0  \n512       17.20        0.52          0.93               0  \n647       21.31        0.52          0.98               0  ",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>timestamp_start</th>\n      <th>timestamp_end</th>\n      <th>qSQI_score</th>\n      <th>cSQI_score</th>\n      <th>sSQI_score</th>\n      <th>kSQI_score</th>\n      <th>pSQI_score</th>\n      <th>basSQI_score</th>\n      <th>classification</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>741</th>\n      <td>58623175</td>\n      <td>58626174</td>\n      <td>0.50</td>\n      <td>0.63</td>\n      <td>3.57</td>\n      <td>17.52</td>\n      <td>0.49</td>\n      <td>0.93</td>\n      <td>0</td>\n    </tr>\n    <tr>\n      <th>640</th>\n      <td>58320141</td>\n      <td>58323140</td>\n      <td>0.80</td>\n      <td>0.88</td>\n      <td>4.19</td>\n      <td>20.28</td>\n      <td>0.50</td>\n      <td>0.98</td>\n      <td>0</td>\n    </tr>\n    <tr>\n      <th>431</th>\n      <td>57693034</td>\n      <td>57696034</td>\n      <td>0.86</td>\n      <td>0.54</td>\n      <td>4.27</td>\n      <td>21.04</td>\n      <td>0.53</td>\n      <td>0.99</td>\n      <td>0</td>\n    </tr>\n    <tr>\n      <th>512</th>\n      <td>57936084</td>\n      <td>57939084</td>\n      <td>1.00</td>\n      <td>0.71</td>\n      <td>3.97</td>\n      <td>17.20</td>\n      <td>0.52</td>\n      <td>0.93</td>\n      <td>0</td>\n    </tr>\n    <tr>\n      <th>647</th>\n      <td>58341143</td>\n      <td>58344142</td>\n      <td>0.75</td>\n      <td>0.52</td>\n      <td>4.27</td>\n      <td>21.31</td>\n      <td>0.52</td>\n      <td>0.98</td>\n      <td>0</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "metadata": {}
    }
   ],
   "source": [
    "# Discard the patient id and the two intermediate label columns, then give the\n",
    "# last remaining column the name \"classification\".\n",
    "df_ml_conso_balanced = df_ml_conso_balanced.drop(columns=['patient', 'classif', 'classif_avg'])\n",
    "df_ml_conso_balanced = df_ml_conso_balanced.rename(\n",
    "    columns={df_ml_conso_balanced.columns[-1]: \"classification\"}\n",
    ")\n",
    "display(df_ml_conso_balanced.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the balanced dataset to disk; index=False leaves the row labels out of the file.\n",
    "df_ml_conso_balanced.to_csv(\n",
    "    '../../datasets/3_ml_patients_consolidation/df_ml_conso_balanced_norm_3s.csv',\n",
    "    index=False,\n",
    ")"
   ]
  }
 ]
}